Skip to content

SamudrAce#782

Open
NickGeneva wants to merge 22 commits intoNVIDIA:mainfrom
NickGeneva:ngeneva/samudrace
Open

SamudrAce#782
NickGeneva wants to merge 22 commits intoNVIDIA:mainfrom
NickGeneva:ngeneva/samudrace

Conversation

@NickGeneva
Copy link
Copy Markdown
Collaborator

@NickGeneva NickGeneva commented Mar 30, 2026

Earth2Studio Pull Request

Description

  • Adds a new SamudrACE prognostic model wrapper that couples the ACE2 atmosphere model with the Samudra ocean model via FME's CoupledStepper, enabling coupled climate inference through Earth2Studio's standard run.deterministic workflow
  • Adds supporting SamudrACEData data source for fetching initial conditions and SamudrACELexicon for variable name mapping between E2S and FME conventions
  • Registers the samudrace optional dependency group (fme>=2025.10.0, pandas, scipy) with conflict declarations against atlas/fcn3/perturbation/sfno extras

Points for discussion

  • Forcing approach here vs ACEERA5
  • Lexicon names for ocean values
  • Model returning 6 hour time steps.
  • Is comparison results accurate enough to attribute floating point error

I believe I found the source of the delta, seems the land sea mask has 1 ULP difference which apparently has a some slight difference with how its loaded. Seems the ocean model is very sensitive to the mask. I checked the input tensors of all other values and this was the only delta. I'm calling it within machine precision.

Modified Files

  • earth2studio/data/__init__.py, earth2studio/lexicon/__init__.py, earth2studio/models/px/__init__.py — Register new classes
  • earth2studio/lexicon/base.py — Add ocean/coupled variable entries to E2STUDIO_VOCAB
  • earth2studio/data/ace2.py — Minor docstring line wrapping
  • pyproject.toml — Add samudrace optional extra with conflict declarations
  • tox.ini — Add samudrace tests to the test-px-models-ace2 environment
  • docs/modules/models_px.rst — Add SamudrACE to prognostic model docs index
  • CHANGELOG.md — Add entry

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.
  • Assess and address Greptile feedback (AI code review bot for guidance; use discretion, addressing all feedback is not required).

Dependencies

@NickGeneva
Copy link
Copy Markdown
Collaborator Author

Comparison scripts

Vanilla FME
import os
from pathlib import Path

from huggingface_hub import snapshot_download

from fme.coupled.inference.inference import main

import torch
import numpy as np

torch.manual_seed(0)
np.random.seed(0)


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
HF_REPO = "allenai/SamudrACE-CM4-piControl"
CACHE_DIR = str(Path.cwd() / "fme_cache")

# Number of coupled steps (each = 20 atm 6h steps + 1 ocean 5-day step)
N_COUPLED_STEPS = 4 

OUTPUT_DIR = os.path.abspath("outputs/fme_reference_output")

print("Downloading SamudrACE files from HuggingFace...")
repo_dir = snapshot_download(repo_id=HF_REPO, cache_dir=CACHE_DIR)
print(f"  Local repo: {repo_dir}")

# The YAML uses relative paths, so we chdir into the snapshot directory.
original_cwd = os.getcwd()
os.chdir(repo_dir)

print(f"Running coupled inference ({N_COUPLED_STEPS} coupled steps)...")
main(
    yaml_config="inference_config.yaml",
    override_dotlist=[
        f"n_coupled_steps={N_COUPLED_STEPS}",
        f"coupled_steps_in_memory={min(N_COUPLED_STEPS, 20)}",
        f"experiment_dir={OUTPUT_DIR}",
    ],
)

os.chdir(original_cwd)
print(f"\nResults saved to {OUTPUT_DIR}/")
print("Done.")
E2S
import earth2studio.run as run
from earth2studio.data.samudrace import SamudrACEData
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import SamudrACE

import torch
import numpy as np

torch.manual_seed(0)
np.random.seed(0)

SCENARIO = "0311"
N_COUPLED_STEPS = 4
N_INNER_STEPS = 20  # 20 atm 6h steps per coupled step
TOTAL_ATM_STEPS = N_COUPLED_STEPS * N_INNER_STEPS  # 40 steps
# Use a proxy time that maps to Jan 1 for forcing lookup (day-of-year matching)
TIME = ["2000-01-01T00:00:00"]

OUTPUT_FILE = "outputs/e2s_reference_output.nc"

print("Loading SamudrACE via Earth2Studio...")
package = SamudrACE.load_default_package()
model = SamudrACE.load_model(package, forcing_scenario=SCENARIO)
data = SamudrACEData(ic_timestamp=f"{SCENARIO}-01-01T00:00:00")

print(f"Running deterministic forecast ({TOTAL_ATM_STEPS} steps)...")
io = run.deterministic(
    time=TIME,
    nsteps=TOTAL_ATM_STEPS,
    prognostic=model,
    data=data,
    io=NetCDF4Backend(file_name=OUTPUT_FILE, backend_kwargs={"mode": "w"}),
)

print(f"\nResults saved to {OUTPUT_FILE}")
print("Done.")

@NickGeneva
Copy link
Copy Markdown
Collaborator Author

NickGeneva commented Mar 30, 2026

Plot Script
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from earth2studio.lexicon.samudrace import SamudrACELexicon

# --------------------------------------------------------------------------- #
# Variable definitions for comparison
# --------------------------------------------------------------------------- #
# Each entry maps an E2S name → FME name, component, and plot styling.
VARIABLES = {
    "sst": {
        "e2s_name": "sst",
        "fme_name": "sst",
        "component": "ocean",
        "label": "Sea Surface Temperature [K]",
        "cmap": "RdYlBu_r",
        "vmin": 270,
        "vmax": 305,
    },
    "t2m": {
        "e2s_name": "t2m",
        "fme_name": "TMP2m",
        "component": "atmosphere",
        "label": "2m Temperature [K]",
        "cmap": "RdYlBu_r",
        "vmin": 200,
        "vmax": 320,
    },
    "sp": {
        "e2s_name": "sp",
        "fme_name": "PRESsfc",
        "component": "atmosphere",
        "label": "Surface Pressure [Pa]",
        "cmap": "viridis",
        "vmin": 50000,
        "vmax": 105000,
    },
    "zos": {
        "e2s_name": "zos",
        "fme_name": "zos",
        "component": "ocean",
        "label": "Sea Surface Height [m]",
        "cmap": "RdBu_r",
        "vmin": -2,
        "vmax": 2,
    },
}

# --------------------------------------------------------------------------- #
# Load datasets
# --------------------------------------------------------------------------- #
print("Loading datasets...")

# FME reference (autoregressive prediction files)
ref_atm_path = "outputs/fme_reference_output/atmosphere/autoregressive_predictions.nc"
ref_ocean_path = "outputs/fme_reference_output/ocean/autoregressive_predictions.nc"

# E2S reference (single NetCDF from run.deterministic)
e2s_path = "outputs/e2s_reference_output.nc"

ds_ref_atm = xr.open_dataset(ref_atm_path)
ds_ref_ocean = xr.open_dataset(ref_ocean_path)
ds_e2s = xr.open_dataset(e2s_path)

print(f"Ref atmosphere:  {dict(ds_ref_atm.dims)}, vars: {len(ds_ref_atm.data_vars)}")
print(
    f"Ref ocean:       {dict(ds_ref_ocean.dims)}, vars: {len(ds_ref_ocean.data_vars)}"
)
print(f"E2S output:      {dict(ds_e2s.dims)}, vars: {len(ds_e2s.data_vars)}")


# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #
# Number of atmosphere steps per coupled cycle (ocean updates once per cycle)
N_INNER_STEPS = 20


def get_ref_ds(component: str) -> xr.Dataset:
    """Return the FME reference dataset for a given component."""
    return ds_ref_atm if component == "atmosphere" else ds_ref_ocean


def get_ref_series(da: xr.DataArray) -> np.ndarray:
    """Extract a (n_steps, lat, lon) array from an FME reference DataArray.

    FME files have dims (sample, time, lat, lon).  We select sample=0 and
    return the time axis as the step dimension.
    """
    return da.isel(sample=0).values  # (time, lat, lon)


def get_e2s_series(da: xr.DataArray, component: str = "atmosphere") -> np.ndarray:
    """Extract a (n_steps, lat, lon) array from the E2S DataArray.

    E2S NetCDF4Backend output has dims (time, lead_time, lat, lon).
    We select time=0 and return the lead_time axis as the step dimension.

    The first lead_time entry (index 0) is the initial condition (lead_time=0h),
    while FME autoregressive_predictions.nc contains only predictions starting at
    the first 6h forecast step.  We therefore skip the IC so that both series
    are aligned by forecast lead time.

    For ocean variables, FME writes one prediction per coupled cycle (every
    N_INNER_STEPS atmosphere steps).  E2S writes the ocean state at every
    atmosphere step.  We sub-sample E2S at the coupled-cycle boundaries
    so the two series are comparable.
    """
    vals = da.isel(time=0).values[1:]  # skip IC; (n_prediction_steps, lat, lon)
    if component == "ocean":
        # Sub-sample at coupled-cycle boundaries: indices 19, 39, 59, ...
        ocean_indices = np.arange(N_INNER_STEPS - 1, len(vals), N_INNER_STEPS)
        vals = vals[ocean_indices]
    return vals


# --------------------------------------------------------------------------- #
# Figure 1: Global-mean timeseries comparison
# --------------------------------------------------------------------------- #
print("\nPlotting global-mean timeseries...")

fig, axes = plt.subplots(len(VARIABLES), 3, figsize=(18, 4 * len(VARIABLES)))
if len(VARIABLES) == 1:
    axes = axes[None, :]

for row, (var_key, var_info) in enumerate(VARIABLES.items()):
    fme_name = var_info["fme_name"]
    e2s_name = var_info["e2s_name"]
    component = var_info["component"]

    ds_ref = get_ref_ds(component)

    if fme_name not in ds_ref.data_vars:
        print(f"  WARNING: {fme_name} not in reference {component} dataset, skipping")
        continue
    if e2s_name not in ds_e2s.data_vars:
        print(f"  WARNING: {e2s_name} not in E2S dataset, skipping")
        continue

    ref_vals = get_ref_series(ds_ref[fme_name])  # (steps, lat, lon)
    e2s_vals = get_e2s_series(ds_e2s[e2s_name], component)  # (steps, lat, lon)

    ref_mean = np.nanmean(ref_vals, axis=(1, 2))
    e2s_mean = np.nanmean(e2s_vals, axis=(1, 2))

    # Align lengths
    n_common = min(len(ref_mean), len(e2s_mean))
    ref_mean = ref_mean[:n_common]
    e2s_mean = e2s_mean[:n_common]
    x_common = np.arange(n_common)

    # Plot 1: Reference
    ax = axes[row, 0]
    ax.plot(x_common, ref_mean, "b-", linewidth=1.5)
    ax.set_ylabel(var_info["label"], fontsize=10)
    if row == 0:
        ax.set_title("Reference (FME)", fontsize=12, fontweight="bold")
    ax.set_xlabel("Step")
    ax.grid(True, alpha=0.3)

    # Plot 2: E2S
    ax = axes[row, 1]
    ax.plot(x_common, e2s_mean, "r-", linewidth=1.5)
    if row == 0:
        ax.set_title("Earth2Studio", fontsize=12, fontweight="bold")
    ax.set_xlabel("Step")
    ax.grid(True, alpha=0.3)

    # Plot 3: Difference
    ax = axes[row, 2]
    diff = e2s_mean - ref_mean
    ax.plot(x_common, diff, "k-", linewidth=1.5)
    ax.axhline(y=0, color="gray", linestyle="--", alpha=0.5)
    if row == 0:
        ax.set_title("Difference (E2S - Ref)", fontsize=12, fontweight="bold")
    ax.set_xlabel("Step")
    ax.grid(True, alpha=0.3)

    rmse = np.sqrt(np.nanmean(diff**2))
    mae = np.nanmean(np.abs(diff))
    ax.text(
        0.02,
        0.95,
        f"RMSE={rmse:.4g}\nMAE={mae:.4g}",
        transform=ax.transAxes,
        fontsize=9,
        verticalalignment="top",
        bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
    )

plt.tight_layout()
ts_file = "compare_timeseries.png"
plt.savefig(ts_file, dpi=150, bbox_inches="tight")
print(f"Saved {ts_file}")
plt.close()


# --------------------------------------------------------------------------- #
# Figure 2: Spatial maps at selected steps
# --------------------------------------------------------------------------- #
print("\nPlotting spatial maps...")

ATM_SPATIAL_STEPS = [0, 19, 39, 59, 79]  # 6h, 120h, 240h, 360h, 480h
OCEAN_SPATIAL_STEPS = [0, 1, 2, 3]  # 120h, 240h, 360h, 480h

for var_key, var_info in VARIABLES.items():
    fme_name = var_info["fme_name"]
    e2s_name = var_info["e2s_name"]
    component = var_info["component"]

    ds_ref = get_ref_ds(component)

    if fme_name not in ds_ref.data_vars or e2s_name not in ds_e2s.data_vars:
        continue

    ref_vals = get_ref_series(ds_ref[fme_name])  # (steps, lat, lon)
    e2s_vals = get_e2s_series(ds_e2s[e2s_name], component)  # (steps, lat, lon)

    n_ref = ref_vals.shape[0]
    n_e2s = e2s_vals.shape[0]

    candidate_steps = OCEAN_SPATIAL_STEPS if component == "ocean" else ATM_SPATIAL_STEPS
    valid_steps = [s for s in candidate_steps if s < n_ref and s < n_e2s]
    if not valid_steps:
        print(f"  No valid steps for {var_key}, skipping spatial plot")
        continue

    n_rows = len(valid_steps)
    fig, axes = plt.subplots(n_rows, 3, figsize=(18, 4 * n_rows))
    if n_rows == 1:
        axes = axes[None, :]

    fig.suptitle(f"{var_info['label']}", fontsize=16, fontweight="bold", y=0.98)

    ref_lat = ds_ref.coords["lat"].values
    ref_lon = ds_ref.coords["lon"].values
    e2s_lat = ds_e2s.coords["lat"].values
    e2s_lon = ds_e2s.coords["lon"].values

    for row_idx, step in enumerate(valid_steps):
        ref_field = ref_vals[step]  # (lat, lon)
        e2s_field = e2s_vals[step]  # (lat, lon)

        # Align grids if lat ordering differs
        if ref_lat[0] < ref_lat[-1] and e2s_lat[0] > e2s_lat[-1]:
            e2s_field = e2s_field[::-1, :]
            plot_lat = ref_lat
        elif ref_lat[0] > ref_lat[-1] and e2s_lat[0] < e2s_lat[-1]:
            ref_field = ref_field[::-1, :]
            plot_lat = e2s_lat
        else:
            plot_lat = ref_lat

        diff = e2s_field - ref_field
        vmin = var_info["vmin"]
        vmax = var_info["vmax"]
        if component == "ocean":
            lead_h = (step + 1) * N_INNER_STEPS * 6  # each ocean step = 120h
        else:
            lead_h = (step + 1) * 6  # step 0 = first prediction = 6h lead

        # Reference
        ax = axes[row_idx, 0]
        im0 = ax.pcolormesh(
            ref_lon,
            plot_lat,
            ref_field,
            cmap=var_info["cmap"],
            vmin=vmin,
            vmax=vmax,
            shading="auto",
        )
        ax.set_ylabel(f"Step {step}\n({lead_h}h)", fontsize=11)
        if row_idx == 0:
            ax.set_title("Reference (FME)", fontsize=12)
        fig.colorbar(im0, ax=ax, shrink=0.8)

        # E2S
        ax = axes[row_idx, 1]
        im1 = ax.pcolormesh(
            e2s_lon,
            plot_lat,
            e2s_field,
            cmap=var_info["cmap"],
            vmin=vmin,
            vmax=vmax,
            shading="auto",
        )
        if row_idx == 0:
            ax.set_title("Earth2Studio", fontsize=12)
        fig.colorbar(im1, ax=ax, shrink=0.8)

        # Difference
        ax = axes[row_idx, 2]
        dmax = max(abs(np.nanmin(diff)), abs(np.nanmax(diff)))
        if dmax == 0:
            dmax = 1.0
        im2 = ax.pcolormesh(
            e2s_lon,
            plot_lat,
            diff,
            cmap="RdBu_r",
            vmin=-dmax,
            vmax=dmax,
            shading="auto",
        )
        if row_idx == 0:
            ax.set_title("Difference (E2S - Ref)", fontsize=12)
        fig.colorbar(im2, ax=ax, shrink=0.8)

        rmse = np.sqrt(np.nanmean(diff**2))
        mae = np.nanmean(np.abs(diff))
        ax.text(
            0.02,
            0.95,
            f"RMSE={rmse:.4g}\nMAE={mae:.4g}",
            transform=ax.transAxes,
            fontsize=8,
            verticalalignment="top",
            bbox=dict(boxstyle="round", facecolor="white", alpha=0.8),
        )

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    out_file = f"compare_{var_key}.png"
    plt.savefig(out_file, dpi=150, bbox_inches="tight")
    print(f"Saved {out_file}")
    plt.close()


# --------------------------------------------------------------------------- #
# Summary statistics
# --------------------------------------------------------------------------- #
print("\n" + "=" * 60)
print("SUMMARY: Per-variable comparison statistics")
print("=" * 60)

for var_key, var_info in VARIABLES.items():
    fme_name = var_info["fme_name"]
    e2s_name = var_info["e2s_name"]
    component = var_info["component"]

    ds_ref = get_ref_ds(component)

    if fme_name not in ds_ref.data_vars or e2s_name not in ds_e2s.data_vars:
        print(f"  {var_key:12s}: MISSING in one or both datasets")
        continue

    ref_vals = get_ref_series(ds_ref[fme_name])
    e2s_vals = get_e2s_series(ds_e2s[e2s_name], component)

    n_common = min(ref_vals.shape[0], e2s_vals.shape[0])
    ref_vals = ref_vals[:n_common]
    e2s_vals = e2s_vals[:n_common]

    diff = e2s_vals - ref_vals
    rmse = np.sqrt(np.nanmean(diff**2))
    mae = np.nanmean(np.abs(diff))
    max_abs = np.nanmax(np.abs(diff))
    close = np.allclose(ref_vals, e2s_vals, rtol=1e-5, atol=1e-5, equal_nan=True)

    print(
        f"  {var_key:12s}: RMSE={rmse:.6g}, MAE={mae:.6g}, "
        f"MaxAbsDiff={max_abs:.6g}, allclose={close}"
    )

print("=" * 60)

ds_ref_atm.close()
ds_ref_ocean.close()
ds_e2s.close()
print("\nComparison complete.")
compare_timeseries compare_zos compare_t2m compare_sst compare_sp

@NickGeneva NickGeneva marked this pull request as ready for review March 30, 2026 23:49
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Mar 30, 2026

Greptile Summary

This PR adds SamudrACE, a coupled atmosphere-ocean prognostic model wrapper that integrates ACE2 (atmosphere) with Samudra (ocean) via FME's CoupledStepper. It includes a supporting SamudrACEData data source, SamudrACELexicon variable mapping, new ocean/coupled vocabulary entries in the base lexicon, and the samudrace optional extras group in pyproject.toml.

The implementation is well-structured and follows Earth2Studio conventions. Most of the prior review concerns have been addressed or acknowledged by the author.

Key observations:

  • The create_iterator / _default_generator path correctly maintains coupled atmosphere-ocean state (atm_state, ocean_state, atm_flux_accum, step_in_cycle) across 6h atmosphere steps and 5-day ocean cycles.
  • The __call__ single-step path always resets step_in_cycle=0, which unconditionally prescribes SST and never advances the ocean — acceptable for single-step use but undocumented.
  • The pyproject.toml samudrace extra pins fme>=2025.10.0 but does not declare a conflict with the ace2 extra (which requires unversioned fme).
  • The shared _forcing_time_index in _get_forcing_slice assumes all time-varying variables share the same time coordinate — an undocumented but reasonable assumption.
  • The netcdf4.py extension to units_map (ns/us/ms entries) is a clean, additive change.

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/documentation improvements that don't block correctness.

Prior P0/P1 concerns from earlier review rounds have been addressed (dead-code block resolved, HfFileSystem instance resolved). Remaining new findings are all P2: a missing conflict declaration in pyproject.toml, an undocumented assumption in the forcing time index, and a call docstring gap. None affect the primary inference path through create_iterator.

pyproject.toml (missing ace2/samudrace conflict), earth2studio/models/px/samudrace.py (call docstring gap and forcing time-index assumption)

Important Files Changed

Filename Overview
earth2studio/models/px/samudrace.py Core 954-line coupled atmosphere-ocean wrapper; logic is sound. Minor P2 concerns: __call__ always resets coupling state, and private FME attribute access (_input_only_names) creates fragile coupling.
earth2studio/data/samudrace.py IC data source; lazy xarray operations before cache cleanup (non-cache path) is a known concern tracked in previous threads.
earth2studio/lexicon/samudrace.py Clean bidirectional mapping between E2S and FME variable names; covers atmosphere (CM4 levels), ocean (CMIP6), and forcing variables.
pyproject.toml Adds samudrace extra with fme>=2025.10.0; missing conflict declaration with ace2 (which also requires fme, unversioned).
earth2studio/io/netcdf4.py Extends units_map with ns/us/ms time-unit entries needed for higher-resolution timedelta64 coordinates; safe additive change.

Reviews (2): Last reviewed commit: "Greptile" | Re-trigger Greptile

@NickGeneva
Copy link
Copy Markdown
Collaborator Author

@greptile-ai

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant